Demo of compile-time rule generation #593

Closed
MasonProtter wants to merge 1 commit into chalk-lab:main from MasonProtter:generated_rule_building

Conversation

@MasonProtter
Collaborator

This is a little demo follow-up to #498.

I made a generated function, Mooncake.generated_build_rrule, that is able to create an rrule at compile time and return it as a compile-time constant. Here's a little demo of it in action:

julia> using Mooncake, BenchmarkTools

julia> @generated function foo(f, args...)
           # I'm just wrapping this in a generated function to more easily show when it gets recompiled
           Core.println("(re-)compiling!")
           quote
               rule = Mooncake.generated_build_rrule(Tuple{typeof(f), typeof.(args)...})
           end
       end;

julia> f(x) = sin(g(x - 1.0));

julia> g(x) = x^2;

julia> foo(f, 1.0)
(re-)compiling!
Mooncake.DerivedRule{Tuple{typeof(f), Float64}, Tuple{Mooncake.CoDual{typeof(f), Mooncake.NoFData}, Mooncake.CoDual{Float64, Mooncake.NoFData}}, Mooncake.CoDual{Float64, Mooncake.NoFData}, Tuple{Float64}, Tuple{Mooncake.NoRData, Float64}, false, Val{2}}(MistyClosure (::Mooncake.CoDual{typeof(f), Mooncake.NoFData}, ::Mooncake.CoDual{Float64, Mooncake.NoFData})::Mooncake.CoDual{Float64, Mooncake.NoFData}->◌, Base.RefValue{MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Float64}, Tuple{Mooncake.NoRData, Float64}}}}(MistyClosure (::Float64)::Tuple{Mooncake.NoRData, Float64}->◌), Val{2}())

This thing only has the overhead of passing around a struct:

julia> @btime foo(f, 1.0);
  2.224 ns (0 allocations: 0 bytes)

And it carries backedges to f, so that if a relevant method instance is invalidated, the rrule is recompiled:

julia> foo(f, 1.0); # no message because it's already compiled

julia> g(x) = x + 1; # invalidate f by changing g

julia> foo(f, 1.0);
(re-)compiling!

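The invalidation shown above rests on Julia's world-age mechanism. As a minimal sketch in plain Julia (no Mooncake involved; `toy_g` is a hypothetical name used only for illustration), redefining a method advances the global world counter, while a call made in the older world still sees the old definition:

```julia
# `toy_g` is a hypothetical function, not part of Mooncake or this PR.
toy_g(x) = x^2
w_old = Base.get_world_counter()   # a world age in which toy_g(x) = x^2 is visible

toy_g(x) = x + 1                   # redefinition advances the global world counter
w_new = Base.get_world_counter()

@show w_new > w_old
# A call pinned to the old world still dispatches to the old definition.
@show Base.invoke_in_world(w_old, toy_g, 3)
# An ordinary top-level call uses the latest world and the new definition.
@show toy_g(3)
```

Anything that was derived from the old definition of `toy_g` is stale in the new world, which is why a rule built for `f` has to be rebuilt once `g` changes.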
@MasonProtter MasonProtter force-pushed the generated_rule_building branch from d9a5d06 to fb1b880 on June 14, 2025 16:46

@willtebbutt
Collaborator

willtebbutt commented Jun 15, 2025

This is very cool @MasonProtter. Before reviewing, I just want to make sure that I've correctly understood what you've done in this PR. Is the below roughly correct?

"This PR introduces a generated function which returns the DerivedRule associated to a given signature. It adds backedges in the generated function to make sure that recompilation happens if a method is added in a later world age which would cause the rule to need to be re-derived. The advantage of this general approach is that the time it takes to get hold of a rule for a given signature is greatly reduced."

@MasonProtter
Collaborator Author

MasonProtter commented Jun 15, 2025

Yes @willtebbutt, that's correct. There is a catch with this approach, though: it seems like an opaque closure created in a generated function cannot be inlined, which means that you hit allocations when actually calling the rule.

I have a branch where I did this for the forward-mode code, and the effect is easier to see there:

julia> function foo(f, x)
           rule = Mooncake.generated_build_frule(Tuple{typeof(f), Float64})
           rule(Dual(f, Mooncake.NoTangent()), Dual(x, 1.0))
       end;

julia> @btime foo(x -> sin(x + 1), 1.0)
  39.581 ns (3 allocations: 96 bytes)
Dual{Float64, Float64}(0.9092974268256817, -0.4161468365471424)

whereas if we compute the rule 'normally' the compiler is allowed to call it without overhead:

julia> function bar(rule, f, x)
           rule(Dual(f, Mooncake.NoTangent()), Dual(x, 1.0))
       end;

julia> let f = x -> sin(x + 1)
           rule = Mooncake.build_frule(Mooncake.get_interpreter(), Tuple{typeof(f), Float64})
           @btime bar($rule, $f, 1.0)
       end
  10.881 ns (0 allocations: 0 bytes)
Dual{Float64, Float64}(0.9092974268256817, -0.4161468365471424)

I'll investigate at some point whether this can be overcome.

@willtebbutt
Collaborator

Great, glad I've understood. Regarding the performance: it seems odd that this introduces allocations. Typically all of the operations associated to passing OpaqueClosures around and calling them are allocation-free, so I feel like it ought to be possible to get rid of these allocations.

@MasonProtter
Collaborator Author

Passing around opaque closures is cheap when the compiler is allowed to inline them, but if it's not, then they're expensive. Here's an example of that:

julia> function make_oc()
           Base.Experimental.@opaque (x::Int) -> sin(x - 1)
       end;

julia> @generated function make_oc_gen()
           Base.Experimental.@opaque (x::Int) -> sin(x - 1)
       end;

julia> @btime f(x) setup=begin
           x = 1
           f = make_oc()
       end
  3.566 ns (0 allocations: 0 bytes)
0.0

julia> @btime f(x) setup=begin
           x = 1
           f = make_oc_gen()
       end
  38.581 ns (2 allocations: 64 bytes)
0.0

Collaborator

@willtebbutt willtebbutt left a comment

Some thoughts on design -- will review for style etc once we're all happy with this. Thanks again for opening this PR, I'm excited about it.

  # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
  # than the current world age.
- if Base.get_world_counter() > interp.world
+ if world > interp.world
Collaborator

Before merging, we should probably tweak the way that Mooncake handles world ages generally. The current approach is something of a hack (notice, for example, that DynamicDerivedRules don't accept a MooncakeInterpreter, and therefore don't have a pinned world age -- they just always grab a new interpreter on-the-fly, and use it to derive a rule). We can worry about that a bit later though.

Collaborator Author

Yeah, that'd be a good idea. I think it'd also potentially be a good idea to keep around a dictionary of all the different interpreters from different worlds, rather than just the latest interpreter. Not sure how you feel about that, though.

Collaborator

Good idea -- what do you have in mind? Make get_interpreter accept a world age + make the constant containing the interpreter a dictionary?

Collaborator Author

yeah exactly.
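
The idea discussed in this thread could look roughly like the sketch below. It is only an illustration under stated assumptions: `ToyInterpreter`, `TOY_INTERPRETERS`, and `toy_get_interpreter` are hypothetical stand-ins, not Mooncake's actual types or API, and a real interpreter carries far more than a world age.

```julia
# Hypothetical stand-in for an interpreter pinned to one world age.
# Mutable so that object identity (===) is meaningful in the cache.
mutable struct ToyInterpreter
    world::UInt
end

# One interpreter per world age, instead of a single "latest" interpreter.
const TOY_INTERPRETERS = Dict{UInt,ToyInterpreter}()
const TOY_INTERPRETERS_LOCK = ReentrantLock()

# Hypothetical analogue of a `get_interpreter` that accepts a world age.
# It defaults to the current world and builds an interpreter only on first request.
function toy_get_interpreter(world::UInt=Base.get_world_counter())
    lock(TOY_INTERPRETERS_LOCK) do
        get!(() -> ToyInterpreter(world), TOY_INTERPRETERS, world)
    end
end

a = toy_get_interpreter(UInt(1))
b = toy_get_interpreter(UInt(1))   # same world: the cached interpreter is reused
c = toy_get_interpreter(UInt(2))   # different world: a separate interpreter
@show a === b
@show a === c
```

The lock is there because a dictionary shared across threads is not safe to mutate concurrently; whether a real implementation would also evict entries for old worlds is left open here.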

Collaborator

This would be great. My thinking is that we should wind up removing the above check entirely, and just rely on correctly propagating the world age through to everything via the interpreter.

Is this something you would be interested in doing as part of this PR, or would you rather it were done as part of a separate one?

interp = MooncakeInterpreter(DefaultCtx; world)
rule = build_rrule(interp, sig; world)

ci = expr_to_codeinfo(@__MODULE__(), [Symbol("#self#"), :sig], [], (), :(return $rule))
Collaborator

I think just returning this is probably unsafe, because the DerivedRule contains various bits of state which get read from / written to in specific orders during the forwards- and reverse passes. The caching functionality in build_rrule therefore makes a call to _copy, which ensures that each time we ask for a rule, it gets fresh state which is definitely not shared with any other rule for this signature. Guaranteeing that no two instances of a rule share the same state avoids race conditions if we're running the same rule on multiple threads, and makes recursion work correctly (the way that we handle recursion is a bit weird).

Anyway, we'll have to do something similar here. You could simply replace the :(return $rule) with :(return _copy($rule)). This will obviously have worse performance (it is not allocation-free), but it is probably going to be better than what we have at the minute, because it will still avoid the type-unstable code here.

In order to do better, we would need to avoid _copy altogether. I suspect that this can be done with some clever caching, but maybe it's something to do in a follow-up PR?
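
To make the shared-state concern concrete, here is a toy sketch in plain Julia. It assumes, purely for illustration, that a rule's state is a stack written on the forwards pass and read on the reverse pass; `ToyRule`, `toy_copy`, `toy_forward!`, and `toy_reverse!` are hypothetical names and not Mooncake's DerivedRule or _copy.

```julia
# Hypothetical toy rule for sin: the forwards pass saves a partial derivative,
# the reverse pass consumes it.
struct ToyRule
    stack::Vector{Float64}
end

# Hypothetical analogue of `_copy`: the same rule, but with fresh, unshared state.
toy_copy(::ToyRule) = ToyRule(Float64[])

function toy_forward!(rule::ToyRule, x::Float64)
    push!(rule.stack, cos(x))      # save d/dx sin(x) for the reverse pass
    return sin(x)
end

toy_reverse!(rule::ToyRule, dy::Float64) = dy * pop!(rule.stack)

# Two users sharing one rule object: the second forwards pass lands on the same stack.
shared = ToyRule(Float64[])
toy_forward!(shared, 0.0)                    # user A
toy_forward!(shared, 1.0)                    # user B
grad_shared = toy_reverse!(shared, 1.0)      # A's reverse pass pops B's entry

# With a fresh copy per user, each reverse pass sees only its own state.
rule_a, rule_b = toy_copy(shared), toy_copy(shared)
toy_forward!(rule_a, 0.0)
toy_forward!(rule_b, 1.0)
grad_copied = toy_reverse!(rule_a, 1.0)

@show grad_shared == cos(1.0)
@show grad_copied == cos(0.0)
```

The interleaving here is sequential for simplicity; with threads or recursion the same sharing would show up as nondeterministic or wrong gradients rather than a tidy, reproducible one.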

Collaborator Author

Ah I see, makes sense. I didn't know there was state inside the rule as well, that's unfortunate.

@vtjnash

vtjnash commented Jun 16, 2025

You likely need to add a __precompile__(false) statement somewhere to prohibit this package from being accessed via precompile. I’m often asked what things @generated is not allowed to do, and this is one of them: using the precompile cache if any of its generated functions may return an OpaqueClosure method. Without that annotation, the package may cause incorrect or bad inference results which may poison any downstream consumer packages. This won’t generally be visible in tests, since tests don’t typically define downstream consumer packages, but it would potentially be a soundness and performance risk for downstream users if Mooncake fails to add that directive when adding this generated function.

@willtebbutt
Collaborator

Thanks for letting us know! I would certainly not have thought about this.

@MasonProtter
Collaborator Author

MasonProtter commented Jun 16, 2025

So we'll likely need a different approach, rather than disabling precompilation (unless we stick the compile-time-generated version in a separate module that has precompilation disabled? But that's a pretty unsatisfying solution and probably quite brittle).

An alternative route we could take here would be to wait for v1.12, and use the mechanism introduced in JuliaLang/julia#56660 to try and do this. Of course, if JuliaLang/julia#56650 were to be resurrected that'd be much better for Mooncake. It seems like @Keno is still somewhat open to that design and having it complement 56660.

Maybe we can discuss with him sometime.

@yebai
Member

yebai commented Jan 9, 2026

Closed in favour of #900

@yebai yebai closed this Jan 9, 2026
4 participants